# -*- coding: utf-8 -*-

import os
from copy import deepcopy
from numbers import Real
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd
import wfdb

from ...cfg import CFG, DEFAULTS
from ...utils import EAK
from ...utils.misc import add_docstring
from ...utils.utils_data import ECGWaveForm, masks_to_waveforms
from ..base import DataBaseInfo, PhysioNetDataBase

# Names exported by ``from <this module> import *``.
__all__ = ["LUDB", "compute_metrics"]


_LUDB_INFO = DataBaseInfo(
    title="""
    Lobachevsky University Electrocardiography Database
    """,
    about="""
    1. consist of 200 10-second conventional 12-lead (i, ii, iii, avr, avl, avf, v1, v2, v3, v4, v5, v6) ECG signal records, with sampling frequency 500 Hz
    2. boundaries of P, T waves and QRS complexes were manually annotated by cardiologists, and with the corresponding diagnosis
    3. annotated are 16797 P waves, 21966 QRS complexes, 19666 T waves (in total, 58429 annotated waves)
    4. distributions of data:

        - rhythm distribution:

            +-------------------------+----------------+
            | Rhythms                 | Number of ECGs |
            +=========================+================+
            | Sinus rhythm            | 143            |
            +-------------------------+----------------+
            | Sinus tachycardia       | 4              |
            +-------------------------+----------------+
            | Sinus bradycardia       | 25             |
            +-------------------------+----------------+
            | Sinus arrhythmia        | 8              |
            +-------------------------+----------------+
            | Irregular sinus rhythm  | 2              |
            +-------------------------+----------------+
            | Abnormal rhythm         | 19             |
            +-------------------------+----------------+

        - electrical axis distribution:

            +----------------------------+----------------+
            | Heart electric axis        | Number of ECGs |
            +============================+================+
            | Normal                     | 75             |
            +----------------------------+----------------+
            | Left axis deviation (LAD)  | 66             |
            +----------------------------+----------------+
            | Vertical                   | 26             |
            +----------------------------+----------------+
            | Horizontal                 | 20             |
            +----------------------------+----------------+
            | Right axis deviation (RAD) | 3              |
            +----------------------------+----------------+
            | Undetermined               | 10             |
            +----------------------------+----------------+

        - distribution of records with conduction abnormalities (totally 79):

            +------------------------------------------------+----------------+
            | Conduction abnormalities                       | Number of ECGs |
            +================================================+================+
            | Sinoatrial blockade, undetermined              | 1              |
            +------------------------------------------------+----------------+
            | I degree AV block                              | 10             |
            +------------------------------------------------+----------------+
            | III degree AV-block                            | 5              |
            +------------------------------------------------+----------------+
            | Incomplete right bundle branch block           | 29             |
            +------------------------------------------------+----------------+
            | Incomplete left bundle branch block            | 6              |
            +------------------------------------------------+----------------+
            | Left anterior hemiblock                        | 16             |
            +------------------------------------------------+----------------+
            | Complete right bundle branch block             | 4              |
            +------------------------------------------------+----------------+
            | Complete left bundle branch block              | 4              |
            +------------------------------------------------+----------------+
            | Non-specific intravintricular conduction delay | 4              |
            +------------------------------------------------+----------------+

        - distribution of records with extrasystoles (totally 35):

            +------------------------------------------------------------------+----------------+
            | Extrasystoles                                                    | Number of ECGs |
            +==================================================================+================+
            | Atrial extrasystole, undetermined                                | 2              |
            +------------------------------------------------------------------+----------------+
            | Atrial extrasystole, low atrial                                  | 1              |
            +------------------------------------------------------------------+----------------+
            | Atrial extrasystole, left atrial                                 | 2              |
            +------------------------------------------------------------------+----------------+
            | Atrial extrasystole, SA-nodal extrasystole                       | 3              |
            +------------------------------------------------------------------+----------------+
            | Atrial extrasystole, type: single PAC                            | 4              |
            +------------------------------------------------------------------+----------------+
            | Atrial extrasystole, type: bigemini                              | 1              |
            +------------------------------------------------------------------+----------------+
            | Atrial extrasystole, type: quadrigemini                          | 1              |
            +------------------------------------------------------------------+----------------+
            | Atrial extrasystole, type: allorhythmic pattern                  | 1              |
            +------------------------------------------------------------------+----------------+
            | Ventricular extrasystole, morphology: polymorphic                | 2              |
            +------------------------------------------------------------------+----------------+
            | Ventricular extrasystole, localisation: RVOT, anterior wall      | 3              |
            +------------------------------------------------------------------+----------------+
            | Ventricular extrasystole, localisation: RVOT, antero-septal part | 1              |
            +------------------------------------------------------------------+----------------+
            | Ventricular extrasystole, localisation: IVS, middle part         | 1              |
            +------------------------------------------------------------------+----------------+
            | Ventricular extrasystole, localisation: LVOT, LVS                | 2              |
            +------------------------------------------------------------------+----------------+
            | Ventricular extrasystole, localisation: LV, undefined            | 1              |
            +------------------------------------------------------------------+----------------+
            | Ventricular extrasystole, type: single PVC                       | 6              |
            +------------------------------------------------------------------+----------------+
            | Ventricular extrasystole, type: intercalary PVC                  | 2              |
            +------------------------------------------------------------------+----------------+
            | Ventricular extrasystole, type: couplet                          | 2              |
            +------------------------------------------------------------------+----------------+

        - distribution of records with hypertrophies (totally 253):

            +-------------------------------+----------------+
            | Hypertrophies                 | Number of ECGs |
            +===============================+================+
            | Right atrial hypertrophy      | 1              |
            +-------------------------------+----------------+
            | Left atrial hypertrophy       | 102            |
            +-------------------------------+----------------+
            | Right atrial overload         | 17             |
            +-------------------------------+----------------+
            | Left atrial overload          | 11             |
            +-------------------------------+----------------+
            | Left ventricular hypertrophy  | 108            |
            +-------------------------------+----------------+
            | Right ventricular hypertrophy | 3              |
            +-------------------------------+----------------+
            | Left ventricular overload     | 11             |
            +-------------------------------+----------------+

        - distribution of records of pacing rhythms (totally 12):

            +------------------------------+----------------+
            | Cardiac pacing               | Number of ECGs |
            +==============================+================+
            | UNIpolar atrial pacing       | 1              |
            +------------------------------+----------------+
            | UNIpolar ventricular pacing  | 6              |
            +------------------------------+----------------+
            | BIpolar ventricular pacing   | 2              |
            +------------------------------+----------------+
            | Biventricular pacing         | 1              |
            +------------------------------+----------------+
            | P-synchrony                  | 2              |
            +------------------------------+----------------+

        - distribution of records with ischemia (totally 141):

            +-------------------------------------------------------+----------------+
            | Ischemia                                              | Number of ECGs |
            +=======================================================+================+
            | STEMI: anterior wall                                  | 8              |
            +-------------------------------------------------------+----------------+
            | STEMI: lateral wall                                   | 7              |
            +-------------------------------------------------------+----------------+
            | STEMI: septal                                         | 8              |
            +-------------------------------------------------------+----------------+
            | STEMI: inferior wall                                  | 1              |
            +-------------------------------------------------------+----------------+
            | STEMI: apical                                         | 5              |
            +-------------------------------------------------------+----------------+
            | Ischemia: anterior wall                               | 5              |
            +-------------------------------------------------------+----------------+
            | Ischemia: lateral wall                                | 8              |
            +-------------------------------------------------------+----------------+
            | Ischemia: septal                                      | 4              |
            +-------------------------------------------------------+----------------+
            | Ischemia: inferior wall                               | 10             |
            +-------------------------------------------------------+----------------+
            | Ischemia: posterior wall                              | 2              |
            +-------------------------------------------------------+----------------+
            | Ischemia: apical                                      | 6              |
            +-------------------------------------------------------+----------------+
            | Scar formation: lateral wall                          | 3              |
            +-------------------------------------------------------+----------------+
            | Scar formation: septal                                | 9              |
            +-------------------------------------------------------+----------------+
            | Scar formation: inferior wall                         | 3              |
            +-------------------------------------------------------+----------------+
            | Scar formation: posterior wall                        | 6              |
            +-------------------------------------------------------+----------------+
            | Scar formation: apical                                | 5              |
            +-------------------------------------------------------+----------------+
            | Undefined ischemia/scar/supp.NSTEMI: anterior wall    | 12             |
            +-------------------------------------------------------+----------------+
            | Undefined ischemia/scar/supp.NSTEMI: lateral wall     | 16             |
            +-------------------------------------------------------+----------------+
            | Undefined ischemia/scar/supp.NSTEMI: septal           | 5              |
            +-------------------------------------------------------+----------------+
            | Undefined ischemia/scar/supp.NSTEMI: inferior wall    | 3              |
            +-------------------------------------------------------+----------------+
            | Undefined ischemia/scar/supp.NSTEMI: posterior wall   | 4              |
            +-------------------------------------------------------+----------------+
            | Undefined ischemia/scar/supp.NSTEMI: apical           | 11             |
            +-------------------------------------------------------+----------------+

        - distribution of records with non-specific repolarization abnormalities (totally 85):

            +-------------------------------------------+----------------+
            | Non-specific repolarization abnormalities | Number of ECGs |
            +===========================================+================+
            | Anterior wall                             |      18        |
            +-------------------------------------------+----------------+
            | Lateral wall                              |      13        |
            +-------------------------------------------+----------------+
            | Septal                                    |      15        |
            +-------------------------------------------+----------------+
            | Inferior wall                             |      19        |
            +-------------------------------------------+----------------+
            | Posterior wall                            |      9         |
            +-------------------------------------------+----------------+
            | Apical                                    |      11        |
            +-------------------------------------------+----------------+

        - there are also 9 records with early repolarization syndrome

       there might well be records with multiple conditions.
    5. ludb.csv stores information about the subjects (gender, age, rhythm type, direction of the electrical axis of the heart, the presence of a cardiac pacemaker, etc.)
    6. Webpage of the database on PhysioNet [1]_. Paper describing the database [2]_.
    """,
    usage=[
        "ECG wave delineation",
        "ECG arrhythmia classification",
    ],
    issues="""
    1. (version 1.0.0, fixed in version 1.0.1) ADC gain might be wrong, either `units` should be μV, or `adc_gain` should be 1000 times larger
    """,
    references=[
        "https://physionet.org/content/ludb/1.0.1/",
        "Kalyakulina, A., Yusipov, I., Moskalenko, V., Nikolskiy, A., Kozlov, A., Kosonogov, K., Zolotykh, N., & Ivanchenko, M. (2020). Lobachevsky University Electrocardiography Database (version 1.0.1).",
    ],
    doi=[
        "10.1109/ACCESS.2020.3029211",
        "10.13026/eegm-h675",
    ],
)


@add_docstring(_LUDB_INFO.format_database_docstring(), mode="prepend")
class LUDB(PhysioNetDataBase):
    """
    Parameters
    ----------
    db_dir : `path-like`, optional
        Storage path of the database.
        If not specified, data will be fetched from Physionet.
    working_dir : `path-like`, optional
        Working directory, to store intermediate files and log files.
    verbose : int, default 1
        Level of logging verbosity.
    kwargs : dict, optional
        Auxiliary keyword arguments.

    """

    __name__ = "LUDB"

    def __init__(
        self,
        db_dir: Optional[Union[str, bytes, os.PathLike]] = None,
        working_dir: Optional[Union[str, bytes, os.PathLike]] = None,
        verbose: int = 1,
        **kwargs: Any,
    ) -> None:
        """Initialize the reader, set database constants and list the records.

        All arguments are forwarded unchanged to the parent class
        constructor, together with the fixed database name ``"ludb"``.
        """
        super().__init__(db_name="ludb", db_dir=db_dir, working_dir=working_dir, verbose=verbose, **kwargs)
        if self.version == "1.0.0":
            self.logger.info("Version of LUDB 1.0.0 has bugs, make sure that version 1.0.1 or higher is used")

        # --- signal constants ---
        self.fs = 500
        # time between two consecutive samples, in milliseconds
        self.spacing = 1000 / self.fs
        self.data_ext = "dat"

        # --- leads and per-lead annotation file extensions ---
        self.all_leads = deepcopy(EAK.Standard12Leads)
        self.all_leads_lower = [lead.lower() for lead in self.all_leads]
        # Version 1.0.0 has different beat_ann_ext: [f"atr_{item}" for item in self.all_leads_lower]
        self.beat_ann_ext = list(self.all_leads_lower)

        # --- annotation symbols ---
        # The full symbol set can be regenerated with the following code:
        #     >>> data_gen = LUDB(db_dir="/path/to/PhysioNet/ludb/1.0.1/")
        #     >>> all_symbols = set()
        #     >>> for rec in data_gen.all_records:
        #     ...     for ext in data_gen.beat_ann_ext:
        #     ...         ann = wfdb.rdann(data_gen.get_absolute_path(rec), extension=ext)
        #     ...         all_symbols.update(ann.symbol)
        self._all_symbols = ["(", ")", "p", "N", "t"]
        self._symbol_to_wavename = CFG({"N": "qrs", "p": "pwave", "t": "twave"})
        self._wavename_to_symbol = CFG({name: symbol for symbol, name in self._symbol_to_wavename.items()})
        # mask labels; "i" is an extra class for isoelectric segments
        self.class_map = CFG(p=1, N=2, t=3, i=0)

        # filled in by `_ls_rec`
        self._df_subject_info = None
        self._ls_rec()

    def _ls_rec(self) -> None:
        """List the records and load the subject information table.

        Record discovery is delegated to the parent class. Afterwards
        ``ludb.csv`` (if present in the database directory) is read into
        :attr:`_df_subject_info`; otherwise an empty table with the
        expected columns is stored instead.
        """
        super()._ls_rec()

        csv_path = self.db_dir / "ludb.csv"
        if not csv_path.is_file():
            # no subject table shipped: keep an empty frame with the known columns
            empty_columns = [
                "ID",
                "Sex",
                "Age",
                "Rhythms",
                "Electric axis of the heart",
                "Conduction abnormalities",
                "Extrasystolies",
                "Hypertrophies",
                "Cardiac pacing",
                "Ischemia",
                "Non-specific repolarization abnormalities",
                "Other states",
            ]
            self._df_subject_info = pd.DataFrame(columns=empty_columns)
            return

        def _stripped(text):
            return text.strip()

        # ``ludb.csv`` is newly added in version 1.0.1
        self._df_subject_info = table = pd.read_csv(csv_path)
        # IDs are kept as strings so that they compare equal to record names
        table.ID = table.ID.apply(str)
        table.Sex = table.Sex.apply(_stripped)
        table.Age = table.Age.apply(_stripped)

    def get_subject_id(self, rec: Union[str, int]) -> int:
        """Return the subject ID of a record.

        The subject ID is simply the record name converted to an integer.

        Parameters
        ----------
        rec : str or int
            Record name or index of the record in :attr:`all_records`.

        Returns
        -------
        int
            Subject ID associated with the record.

        """
        # an integer is interpreted as an index into the record list
        record_name = self[rec] if isinstance(rec, int) else rec
        return int(record_name)

    def get_absolute_path(self, rec: Union[str, int], extension: Optional[str] = None) -> Path:
        """Get the absolute path of the record.

        Parameters
        ----------
        rec : str or int
            Record name or index of the record in :attr:`all_records`.
        extension : str, optional
            Extension of the file.

        Returns
        -------
        pathlib.Path
            Absolute path of the file.

        """
        if isinstance(rec, int):
            rec = self[rec]
        if extension is not None and not extension.startswith("."):
            extension = f".{extension}"
        return self.db_dir / "data" / f"{rec}{extension or ''}"

    def load_ann(
        self,
        rec: Union[str, int],
        leads: Optional[Union[str, int, Sequence[Union[str, int]]]] = None,
        metadata: bool = False,
    ) -> dict:
        """Load the wave delineation, along with metadata if specified.

        Each lead has its own annotation file whose extension is the
        lower-cased lead name. In these files a wave is written as a wave
        symbol (``"p"``, ``"N"`` or ``"t"``, located at the peak),
        optionally preceded by ``"("`` (onset) and followed by ``")"``
        (offset). A missing bracket makes the corresponding boundary
        fall back to the peak position.

        Parameters
        ----------
        rec : str or int
            Record name or index of the record in :attr:`all_records`.
        leads : str or int or List[str] or List[int], optional
            The leads of the wave delineation to be loaded.
        metadata : bool, default False
            If True, metadata will be loaded from corresponding head file.

        Returns
        -------
        ann_dict : dict
            The wave delineation annotations. The item ``"waves"`` maps
            each lead name to a list of :class:`ECGWaveForm`, in the order
            of appearance in the annotation file; ``duration`` is in
            milliseconds. If `metadata` is True, the items of the header
            dict are merged in as well.

        """
        if isinstance(rec, int):
            rec = self[rec]
        ann_dict = CFG()
        rec_fp = str(self.get_absolute_path(rec))

        # wave delineation annotations
        _leads = self._normalize_leads(leads)
        ann_dict["waves"] = CFG({ld: [] for ld in _leads})
        for ld in _leads:
            # for Version 1.0.0, the extension is f"atr_{ld.lower()}" instead
            ann = wfdb.rdann(rec_fp, extension=ld.lower())
            symbols = np.array(ann.symbol)
            samples = np.asarray(ann.sample)
            last_idx = len(symbols) - 1
            lead_waves = ann_dict["waves"][ld]
            for peak_idx in np.flatnonzero(np.isin(symbols, ["p", "N", "t"])):
                peak = int(samples[peak_idx])
                # neighbours are only inspected when they exist, so a wave symbol
                # at the very beginning or end of the file is handled safely
                has_onset = peak_idx > 0 and symbols[peak_idx - 1] == "("
                has_offset = peak_idx < last_idx and symbols[peak_idx + 1] == ")"
                onset = int(samples[peak_idx - 1]) if has_onset else peak
                offset = int(samples[peak_idx + 1]) if has_offset else peak
                lead_waves.append(
                    ECGWaveForm(
                        name=self._symbol_to_wavename[str(symbols[peak_idx])],
                        onset=onset,
                        offset=offset,
                        peak=peak,
                        duration=(offset - onset) * self.spacing,
                    )
                )

        if metadata:
            header_dict = self._load_header(rec)
            ann_dict.update(header_dict)

        return ann_dict

    def load_diagnoses(self, rec: Union[str, int]) -> List[str]:
        """Load diagnoses of the record.

        The diagnoses are taken from the ``"diagnoses"`` item of the
        header dict of the record.

        Parameters
        ----------
        rec : str or int
            Record name or index of the record in :attr:`all_records`.

        Returns
        -------
        diagnoses : List[str]
            List of diagnoses of the record.

        """
        header = self._load_header(rec)
        return header["diagnoses"]

    def load_masks(
        self,
        rec: Union[str, int],
        leads: Optional[Union[str, int, Sequence[Union[str, int]]]] = None,
        mask_format: str = "channel_first",
        class_map: Optional[Dict[str, int]] = None,
    ) -> np.ndarray:
        """Load the wave delineation in the form of masks.

        The mask has the same shape as the channel-first signal of the
        record. It is initialized with the label of the class ``"i"`` of
        the class map, then every annotated wave writes its label on the
        samples ``onset`` (inclusive) to ``offset`` (exclusive).

        Parameters
        ----------
        rec : str or int
            Record name or index of the record in :attr:`all_records`.
        leads : str or int or List[str] or List[int], optional
            The leads of the wave delineation to be loaded.
        mask_format : str, default "channel_first"
            Format of the mask,
            "channel_last" (alias "lead_last"), or
            "channel_first" (alias "lead_first").
            Any value other than the channel-first names
            yields the transposed (channel-last) mask.
        class_map : dict, optional
            Custom class map.
            If not set, `self.class_map` will be used.

        Returns
        -------
        masks : numpy.ndarray
            The masks corresponding to the wave delineation annotations
            of the record.

        """
        if isinstance(rec, int):
            rec = self[rec]
        _class_map = self.class_map if class_map is None else CFG(class_map)

        # the signal is loaded only to obtain the shape of the mask
        signal = self.load_data(rec, leads=leads, data_format="channel_first")
        masks = np.full_like(signal, fill_value=_class_map.i, dtype=int)

        lead_waves = self.load_ann(rec, leads=leads, metadata=False)["waves"]
        for row, wave_list in enumerate(lead_waves.values()):
            for wave in wave_list:
                label = _class_map[self._wavename_to_symbol[wave.name]]
                masks[row, wave.onset : wave.offset] = label

        if mask_format.lower() in ("channel_first", "lead_first"):
            return masks
        return masks.T

    def from_masks(
        self,
        masks: np.ndarray,
        mask_format: str = "channel_first",
        leads: Optional[Union[str, int, Sequence[Union[str, int]]]] = None,
        class_map: Optional[Dict[str, int]] = None,
        fs: Optional[Real] = None,
    ) -> Dict[str, List[ECGWaveForm]]:
        """Convert masks into lists of waveforms.

        For every lead and every wave class, each maximal run of
        consecutive samples carrying the label of that class becomes one
        :class:`ECGWaveForm`.

        Parameters
        ----------
        masks : numpy.ndarray
            Wave delineation in the form of masks,
            of shape ``(n_leads, seq_len)`` or ``(seq_len,)``.
        mask_format : str, default "channel_first"
            Format of the mask, used only when ``masks.ndim = 2``.
            One of "channel_last" (alias "lead_last"), or
            "channel_first" (alias "lead_first").
        leads : str or int or List[str] or List[int], optional
            The leads the rows of the masks correspond to.
            If not set, the names ``"lead_1"``, ``"lead_2"``, ... are used.
        class_map : dict, optional
            Custom class map.
            If not set, `self.class_map` will be used.
            Only the items with keys "p", "N", "t" are converted.
        fs : numbers.Real, optional
            Sampling frequency of the signal corresponding to the `masks`,
            If is None, `self.fs` will be used,
            to compute `duration` of the ECG waveforms.

        Returns
        -------
        waves : dict
            The wave delineation annotations of the record.
            Each item value of the dict is a list
            containing the :class:`ECGWaveForm` corr. to the lead (item key),
            sorted by onset. ``onset`` is inclusive, ``offset`` is exclusive,
            ``peak`` is NaN (it cannot be recovered from a mask), and
            ``duration`` is in milliseconds.

        Raises
        ------
        ValueError
            If `masks` is neither 1d nor 2d.
        AssertionError
            If the number of leads does not match the number of rows
            of the masks.

        """
        # the masks are only read, never modified, so views are sufficient
        if masks.ndim == 1:
            _masks = masks[np.newaxis, ...]
        elif masks.ndim == 2:
            if mask_format.lower() in ("channel_first", "lead_first"):
                _masks = masks
            else:
                _masks = masks.T
        else:
            raise ValueError(f"masks should be of dim 1 or 2, but got a {masks.ndim}d array")

        if leads is not None:
            _leads = self._normalize_leads(leads)
        else:
            _leads = [f"lead_{idx+1}" for idx in range(_masks.shape[0])]
        assert len(_leads) == _masks.shape[0], "number of leads does not match the number of rows of the masks"

        _class_map = CFG(class_map) if class_map is not None else self.class_map

        _fs = fs if fs is not None else self.fs

        waves = CFG({lead_name: [] for lead_name in _leads})
        for channel_idx, lead_name in enumerate(_leads):
            current_mask = _masks[channel_idx, ...]
            for wave_symbol, wave_number in _class_map.items():
                # skip classes that are not waves (e.g. the isoelectric class)
                if wave_symbol not in ("p", "N", "t"):
                    continue
                wave_name = self._symbol_to_wavename[wave_symbol]
                current_wave_inds = np.flatnonzero(current_mask == wave_number)
                if current_wave_inds.size == 0:
                    continue
                # cut the sorted sample indices wherever two neighbours are not adjacent
                run_starts = np.flatnonzero(np.diff(current_wave_inds) > 1) + 1
                for run in np.split(current_wave_inds, run_starts):
                    # plain Python ints, as in the waveforms produced by `load_ann`
                    itv_start = int(run[0])
                    itv_end = int(run[-1]) + 1
                    w = ECGWaveForm(
                        name=wave_name,
                        onset=itv_start,
                        offset=itv_end,
                        peak=np.nan,
                        duration=1000 * (itv_end - itv_start) / _fs,  # ms
                    )
                    waves[lead_name].append(w)
            waves[lead_name].sort(key=lambda w: w.onset)
        return waves

    def _load_header(self, rec: Union[str, int]) -> dict:
        """Load header data of a record into a dict.

        Parameters
        ----------
        rec : str or int
            Record name or index of the record in :attr:`all_records`.

        Returns
        -------
        header_dict : dict
            The header data of the record, with items

            - "units", "baseline", "adc_gain", "record_fmt":
              taken directly from the header;
            - "age": int parsed from the ``<age>`` comment line,
              NaN if the line is missing or not an integer;
            - "sex": text of the ``<sex>`` comment line,
              empty string if the line is missing;
            - "diagnoses": all comment lines after the ``<diagnoses>``
              line, empty list if the line is missing.

        """
        header_dict = CFG({})
        rec_fp = str(self.get_absolute_path(rec))
        header_reader = wfdb.rdheader(rec_fp)
        header_dict["units"] = header_reader.units
        header_dict["baseline"] = header_reader.baseline
        header_dict["adc_gain"] = header_reader.adc_gain
        header_dict["record_fmt"] = header_reader.fmt

        # all remaining items are parsed from the comment lines, best-effort
        comments = header_reader.comments or []

        age_line = next((line for line in comments if "<age>" in line), None)
        try:
            header_dict["age"] = np.nan if age_line is None else int(age_line.split(": ")[-1])
        except ValueError:
            # the age is present but not written as an integer
            header_dict["age"] = np.nan

        sex_line = next((line for line in comments if "<sex>" in line), None)
        header_dict["sex"] = "" if sex_line is None else sex_line.split(": ")[-1]

        # a missing <diagnoses> tag gives an empty list instead of an IndexError,
        # in line with the handling of the age and sex fields
        tag_idx = next((idx for idx, line in enumerate(comments) if "<diagnoses>" in line), None)
        header_dict["diagnoses"] = [] if tag_idx is None else comments[tag_idx + 1 :]
        return header_dict

    def load_subject_info(self, rec: Union[str, int], fields: Optional[Union[str, Sequence[str]]] = None) -> Union[dict, str]:
        """Load subject info of a record.

        Parameters
        ----------
        rec : str or int
            Record name or index of the record in :attr:`all_records`.
        fields : str or Sequence[str], optional
            Field(s) of the subject info to load.
            If not specified, all fields are loaded.

        Returns
        -------
        info : dict or str
            If `fields` is a single string, the value of that field;
            otherwise a dict restricted to the requested fields
            (all fields if `fields` is None).
            An empty dict is returned if the record has no row
            in the subject info table.

        """
        if isinstance(rec, int):
            rec = self[rec]

        table = self._df_subject_info
        matched = table[table.ID == rec]
        if matched.empty:
            return {}

        # only the first matching row is used
        info = matched.iloc[0].to_dict()
        if fields is None:
            return info

        columns = table.columns
        if isinstance(fields, str):
            assert fields in columns, f"No field `{fields}`"
            return info[fields]

        assert set(fields).issubset(set(columns)), f"No field(s) {set(fields).difference(set(columns))}"
        return {key: value for key, value in info.items() if key in fields}

    def plot(
        self,
        rec: Union[str, int],
        data: Optional[np.ndarray] = None,
        ticks_granularity: int = 0,
        leads: Optional[Union[str, int, Sequence[Union[str, int]]]] = None,
        same_range: bool = False,
        waves: Optional[Dict[str, Sequence[ECGWaveForm]]] = None,
        **kwargs: Any,
    ) -> None:
        """
        Plot the signals of a record or external signals (units in μV),
        with the diagnoses of the record listed in the legend,
        possibly also along with wave delineations.

        Parameters
        ----------
        rec : str or int
            Record name or index of the record in :attr:`all_records`.
        data : numpy.ndarray, optional
            ECG signal to plot, one row per lead.
            If is not None, data of `rec` will not be used
            (and `leads` must be specified).
            Its units are auto inferred; data inferred to be in mV
            are multiplied by 1000 before plotting.
            This is useful when plotting filtered data
        ticks_granularity : int, default 0
            Granularity to plot axis ticks, the higher the more ticks.
            0 (no ticks) --> 1 (major ticks) --> 2 (major + minor ticks)
        leads : str or int or List[str] or List[int], optional
            The leads of the ECG signal to plot.
        same_range : bool, default False
            If True, all leads are forced to have the same y range.
        waves : Dict[str, Sequence[ECGWaveForm]], optional
            The waves (p waves, t waves, qrs complexes) to plot,
            as a dict mapping lead names to sequences of :class:`ECGWaveForm`.
            If both `data` and `waves` are None,
            the annotated waves of `rec` are loaded and plotted.
            Entries for leads that are not plotted are ignored.
        kwargs : dict, optional
            Only ``save_path`` is used: if given (and truthy),
            the figure is saved to this path instead of being shown.

        TODO
        ----
        1. Slice too long records, and plot separately for each segment.

        NOTE
        ----
        `Locator` of ``plt`` has default `MAXTICKS` of 1000.
        If not modifying this number, at most 40 seconds of signal could be plotted once.

        Contributors: Jeethan, and WEN Hao

        """
        if isinstance(rec, int):
            rec = self[rec]

        # local import (the original guard `"plt" not in dir()` was always true here)
        import matplotlib.pyplot as plt

        # raise the tick limit of the locators, see NOTE in the docstring
        plt.MultipleLocator.MAXTICKS = 3000

        if data is not None:
            assert leads is not None, "`leads` must be specified when `data` is given"
            data = np.atleast_2d(data)
            _leads = self._normalize_leads(leads)
            _lead_indices = [self.all_leads.index(ld) for ld in _leads]
            assert len(_leads) == data.shape[0], "number of leads must match data"
            units = self._auto_infer_units(data)
            self.logger.info(f"input data is auto detected to have units in {units}")
            if units.lower() == "mv":
                _data = 1000 * data
            else:
                _data = data
        else:
            _leads = self._normalize_leads(leads)
            _lead_indices = [self.all_leads.index(ld) for ld in _leads]
            _data = self.load_data(rec, data_format="channel_first", units="μV")[_lead_indices]

        # half height of the y axis of each lead, with a margin of 100 μV
        if same_range:
            y_ranges = np.ones((_data.shape[0],)) * np.max(np.abs(_data)) + 100
        else:
            y_ranges = np.max(np.abs(_data), axis=1) + 100

        if data is None and waves is None:
            waves = self.load_ann(rec, leads=_leads)["waves"]

        # [onset, offset] intervals (in samples) of each kind of wave, per lead
        pwaves = {ld: [] for ld in _leads}
        qrs = {ld: [] for ld in _leads}
        twaves = {ld: [] for ld in _leads}
        if waves is not None:
            for l_idx, l_w in waves.items():
                if l_idx not in pwaves:
                    # waves of a lead that is not plotted
                    continue
                for wv in l_w:
                    itv = [wv.onset, wv.offset]
                    if wv.name == self._symbol_to_wavename["p"]:
                        pwaves[l_idx].append(itv)
                    elif wv.name == self._symbol_to_wavename["N"]:
                        qrs[l_idx].append(itv)
                    elif wv.name == self._symbol_to_wavename["t"]:
                        twaves[l_idx].append(itv)
        wave_intervals = {"pwaves": pwaves, "qrs": qrs, "twaves": twaves}

        palette = {
            "pwaves": "cyan",
            "qrs": "green",
            "twaves": "yellow",
        }
        plot_alpha = 0.4

        diagnoses = self.load_diagnoses(rec)

        nb_leads = len(_leads)

        t = np.arange(_data.shape[1]) / self.fs
        duration = len(t) / self.fs
        fig_sz_w = int(round(4.8 * duration))
        fig_sz_h = 6 * y_ranges / 1500
        fig, axes = plt.subplots(nb_leads, 1, sharex=True, figsize=(fig_sz_w, np.sum(fig_sz_h)))
        if nb_leads == 1:
            # `plt.subplots` returns a single Axes instead of an array in this case
            axes = [axes]
        for idx in range(nb_leads):
            lead_name = self.all_leads[_lead_indices[idx]]
            axes[idx].plot(
                t,
                _data[idx],
                color="black",
                linewidth="2.0",
                label=f"lead - {lead_name}",
            )
            axes[idx].axhline(y=0, linestyle="-", linewidth="1.0", color="red")
            # NOTE that `Locator` has default `MAXTICKS` equal to 1000
            if ticks_granularity >= 1:
                axes[idx].xaxis.set_major_locator(plt.MultipleLocator(0.2))
                axes[idx].yaxis.set_major_locator(plt.MultipleLocator(500))
                axes[idx].grid(which="major", linestyle="-", linewidth="0.5", color="red")
            if ticks_granularity >= 2:
                axes[idx].xaxis.set_minor_locator(plt.MultipleLocator(0.04))
                axes[idx].yaxis.set_minor_locator(plt.MultipleLocator(100))
                axes[idx].grid(which="minor", linestyle=":", linewidth="0.5", color="black")
            # add extra info. to legend
            # https://stackoverflow.com/questions/16826711/is-it-possible-to-add-a-string-as-a-legend-item-in-matplotlib
            for d in diagnoses:
                axes[idx].plot([], [], " ", label=d)
            # shade the wave intervals, converting sample indices to seconds
            for w, color in palette.items():
                for itv in wave_intervals[w][lead_name]:
                    axes[idx].axvspan(
                        itv[0] / self.fs,
                        itv[1] / self.fs,
                        color=color,
                        alpha=plot_alpha,
                    )
            axes[idx].legend(loc="upper left")
            axes[idx].set_xlim(t[0], t[-1])
            axes[idx].set_ylim(-y_ranges[idx], y_ranges[idx])
            axes[idx].set_xlabel("Time [s]")
            axes[idx].set_ylabel("Voltage [μV]")
        plt.subplots_adjust(hspace=0.2)
        if kwargs.get("save_path", None):
            plt.savefig(kwargs["save_path"], dpi=200, bbox_inches="tight")
        else:
            plt.show()

    @property
    def database_info(self) -> DataBaseInfo:
        """The :class:`DataBaseInfo` object of this database (the module-level ``_LUDB_INFO``)."""
        return _LUDB_INFO


# Matching tolerance in milliseconds: `_compute_metrics_base` converts it to a
# number of samples (``__TOLERANCE * fs / 1000``) and counts a ground truth point
# as detected if some prediction lies within that distance.
__TOLERANCE = 150  # ms
# Names of the wave classes.
# NOTE(review): not referenced by the metric functions below, which spell the names out.
__WaveNames = ["pwave", "qrs", "twave"]


def compute_metrics(
    truth_masks: Sequence[np.ndarray],
    pred_masks: Sequence[np.ndarray],
    class_map: Dict[str, int],
    fs: Real,
    mask_format: str = "channel_first",
) -> Dict[str, Dict[str, float]]:
    """Compute metrics for the wave delineation task.

    The masks are converted into waveforms via :func:`masks_to_waveforms`,
    one list of waveforms per mask, and the metrics
    (sensitivity, precision, f1_score, mean error
    and standard deviation of the mean errors)
    are then computed by :func:`compute_metrics_waveform`.

    Parameters
    ----------
    truth_masks : Sequence[numpy.ndarray]
        A sequence of ground truth masks,
        each of which can also hold multiple masks
        from different samples (differ by record or by lead).
    pred_masks : Sequence[numpy.ndarray]
        Predictions corresponding to `truth_masks`.
    class_map : Dict[str, int]
        Class map, mapping names to waves to numbers from 0 to n_classes-1,
        the keys should contain "pwave", "qrs", "twave".
    fs : numbers.Real
        Sampling frequency of the signal corresponding to the masks,
        used to compute the duration of each waveform,
        hence the error and standard deviations of errors.
    mask_format : str, default "channel_first"
        Format of the mask, one of the following:
        "channel_last" (alias "lead_last"), or
        "channel_first" (alias "lead_first").

    Returns
    -------
    scorings : dict
        A dictionary containing the scorings of onsets and offsets
        of pwaves, qrs complexes, twaves.
        Each scoring is a dict consisting of the following metrics:
        sensitivity, precision, f1_score, mean_error, standard_deviation.

    """
    assert len(truth_masks) == len(pred_masks)

    def _waveforms_per_mask(masks: np.ndarray, n_masks: int) -> list:
        """Convert `masks` to a list (one item per mask) of lists of `ECGWaveForm`s."""
        by_lead = masks_to_waveforms(masks, class_map, fs, mask_format)
        return [by_lead[f"lead_{idx+1}"] for idx in range(n_masks)]

    truth_waveforms, pred_waveforms = [], []
    for t_mask, p_mask in zip(truth_masks, pred_masks):
        # the number of masks is read from the ground truth along the channel axis
        channel_axis = 0 if mask_format.lower() in ["channel_first", "lead_first"] else 1
        n_masks = t_mask.shape[channel_axis]
        truth_waveforms.extend(_waveforms_per_mask(t_mask, n_masks))
        pred_waveforms.extend(_waveforms_per_mask(p_mask, n_masks))

    return compute_metrics_waveform(truth_waveforms, pred_waveforms, fs)


def compute_metrics_waveform(
    truth_waveforms: Sequence[Sequence[ECGWaveForm]],
    pred_waveforms: Sequence[Sequence[ECGWaveForm]],
    fs: Real,
) -> Dict[str, Dict[str, float]]:
    """
    Compute the sensitivity, precision, f1_score, mean error
    and standard deviation of the mean errors,
    of evaluations on multiple samples (differ by records, or leads).

    The counts and errors of each sample are obtained from
    :func:`_compute_metrics_waveform` and pooled over all samples
    before the metrics are computed.

    Parameters
    ----------
    truth_waveforms : Sequence[Sequence[ECGWaveForm]]
        The ground truth. Each element is a sequence of
        :class:`ECGWaveForm` from the same sample.
    pred_waveforms : Sequence[Sequence[ECGWaveForm]]
        The predictions corresponding to `truth_waveforms`.
        Each element is a sequence of :class:`ECGWaveForm`
        from the same sample.
    fs : numbers.Real
        Sampling frequency of the signal corresponding to the waveforms,
        used to compute the duration of each waveform,
        hence the error and standard deviations of errors.

    Returns
    -------
    scorings : dict
        A dictionary containing the scorings of onsets and offsets
        of pwaves, qrs complexes, twaves.
        Each scoring is a dict consisting of the following metrics:
        sensitivity, precision, f1_score, mean_error, standard_deviation.

    """
    # "pwave_onset", "pwave_offset", "qrs_onset", ...
    keys = [f"{wave}_{term}" for wave in ["pwave", "qrs", "twave"] for term in ["onset", "offset"]]

    truth_positive = CFG({key: 0 for key in keys})
    false_positive = CFG({key: 0 for key in keys})
    false_negative = CFG({key: 0 for key in keys})
    errors = CFG({key: [] for key in keys})

    # accumulating results
    for sample_truths, sample_preds in zip(truth_waveforms, pred_waveforms):
        sample_scorings = _compute_metrics_waveform(sample_truths, sample_preds, fs)
        for key in keys:
            truth_positive[key] += sample_scorings[key]["truth_positive"]
            false_positive[key] += sample_scorings[key]["false_positive"]
            false_negative[key] += sample_scorings[key]["false_negative"]
            errors[key] += sample_scorings[key]["errors"]

    scorings = CFG()
    for key in keys:
        tp, fp, fn = truth_positive[key], false_positive[key], false_negative[key]
        sensitivity = tp / (tp + fn + DEFAULTS.eps)
        precision = tp / (tp + fp + DEFAULTS.eps)
        # errors are in samples, converted to milliseconds
        scorings[key] = CFG(
            sensitivity=sensitivity,
            precision=precision,
            f1_score=2 * sensitivity * precision / (sensitivity + precision + DEFAULTS.eps),
            mean_error=np.mean(errors[key]) * 1000 / fs,
            standard_deviation=np.std(errors[key]) * 1000 / fs,
        )

    return scorings


def _compute_metrics_waveform(
    truths: Sequence[ECGWaveForm], preds: Sequence[ECGWaveForm], fs: Real
) -> Dict[str, Dict[str, float]]:
    """
    compute the sensitivity, precision, f1_score, mean error
    and standard deviation of the mean errors,
    of evaluations on a single sample (the same record, the same lead).

    Parameters
    ----------
    truths : Sequence[ECGWaveForm]
        The ground truth.
    preds : Sequence[ECGWaveForm]
        The predictions corresponding to `truths`,
    fs : numbers.Real
        Sampling frequency of the signal corresponding to the waveforms,
        used to compute the duration of each waveform,
        hence the error and standard deviations of errors.

    Returns
    -------
    scorings : dict
        A dictionary containing the scorings of onsets and offsets
        of pwaves, qrs complexes, twaves.
        Each scoring is a dict consisting of the following metrics:
        truth_positive, false_negative, false_positive, errors,
        sensitivity, precision, f1_score, mean_error, standard_deviation.

    NOTE
    ----
    Waveforms whose name is not one of "pwave", "qrs", "twave" are ignored.

    """
    wave_names = ("pwave", "qrs", "twave")
    terms = ("onset", "offset")
    keys = [f"{wave}_{term}" for wave in wave_names for term in terms]

    # critical points (onsets and offsets) grouped by wave name and term;
    # plain dict lookups replace the former `eval` on strings built from `w.name`
    truth_points = {key: [] for key in keys}
    pred_points = {key: [] for key in keys}
    for waveforms, points in ((truths, truth_points), (preds, pred_points)):
        for w in waveforms:
            if w.name not in wave_names:
                continue
            for term in terms:
                points[f"{w.name}_{term}"].append(getattr(w, term))

    scorings = CFG()
    for key in keys:
        (
            truth_positive,
            false_negative,
            false_positive,
            errors,
            sensitivity,
            precision,
            f1_score,
            mean_error,
            standard_deviation,
        ) = _compute_metrics_base(truth_points[key], pred_points[key], fs)
        scorings[key] = CFG(
            truth_positive=truth_positive,
            false_negative=false_negative,
            false_positive=false_positive,
            errors=errors,
            sensitivity=sensitivity,
            precision=precision,
            f1_score=f1_score,
            mean_error=mean_error,
            standard_deviation=standard_deviation,
        )
    return scorings


def _compute_metrics_base(
    truths: Sequence[Real], preds: Sequence[Real], fs: Real
) -> Tuple[int, int, int, List[float], float, float, float, float, float]:
    """The base function for computing the metrics.

    A ground truth point counts as detected (truth positive) if at least
    one prediction lies within the tolerance (``__TOLERANCE`` milliseconds);
    the signed distance (in samples) from the point to the closest such
    prediction is recorded as its error. Otherwise it is a false negative.

    Parameters
    ----------
    truths : Sequence[Real]
        Ground truth of indices of corresponding critical points.
    preds : Sequence[Real]
        Predicted indices of corresponding critical points.
    fs : numbers.Real
        Sampling frequency of the signal corresponding to the critical points,
        used to compute the duration of each waveform,
        hence the error and standard deviations of errors.

    Returns
    -------
    tuple
        tuple of the following metrics:
        truth_positive, false_negative, false_positive, errors,
        sensitivity, precision, f1_score, mean_error, standard_deviation.
        `errors` are in samples, while `mean_error` and `standard_deviation`
        are in milliseconds, and are NaN if no ground truth point is detected.

    """
    _tolerance = __TOLERANCE * fs / 1000  # ms -> number of samples
    _preds = np.array(preds)
    truth_positive, false_negative = 0, 0
    errors = []
    for point in truths:
        deviations = _preds - point
        within = deviations[np.abs(deviations) <= _tolerance]
        if len(within) > 0:
            truth_positive += 1
            errors.append(within[np.argmin(np.abs(within))])
        else:
            false_negative += 1

    # one prediction can lie within the tolerance of several ground truth points,
    # in which case `truth_positive` exceeds `len(_preds)`; clamp at 0 so that the
    # count is never negative and the precision never exceeds 1
    false_positive = max(len(_preds) - truth_positive, 0)

    sensitivity = truth_positive / (truth_positive + false_negative + DEFAULTS.eps)
    precision = truth_positive / (truth_positive + false_positive + DEFAULTS.eps)
    f1_score = 2 * sensitivity * precision / (sensitivity + precision + DEFAULTS.eps)
    if errors:
        mean_error = np.mean(errors) * 1000 / fs
        standard_deviation = np.std(errors) * 1000 / fs
    else:
        # same values as `np.mean([])` and `np.std([])`, without the RuntimeWarning
        mean_error = standard_deviation = np.nan

    return (
        truth_positive,
        false_negative,
        false_positive,
        errors,
        sensitivity,
        precision,
        f1_score,
        mean_error,
        standard_deviation,
    )
